From: Taku Kudo Date: Fri, 5 Aug 2022 07:34:44 +0000 (+0900) Subject: support slice in pieces/nbests objects X-Git-Tag: archive/raspbian/0.1.97-3+rpi1^2~9 X-Git-Url: https://dgit.raspbian.org/%22http://www.example.com/cgi/%22/%22http:/www.example.com/cgi/%22?a=commitdiff_plain;h=487d7fc1b424d4d7c0a33be5c02cc6dfe1e689fd;p=sentencepiece.git support slice in pieces/nbests objects Signed-off-by: Kentaro Hayashi Gbp-Pq: Name 0019-support-slice-in-pieces-nbests-objects.patch --- diff --git a/python/src/sentencepiece/__init__.py b/python/src/sentencepiece/__init__.py index ce9d60d..cf06830 100644 --- a/python/src/sentencepiece/__init__.py +++ b/python/src/sentencepiece/__init__.py @@ -145,6 +145,10 @@ class ImmutableSentencePieceText(object): return self.len def __getitem__(self, index): + if isinstance(index, slice): + return [self.proto._pieces(i) for i in range(self.len)][index.start:index.stop:index.step] + if index < 0: + index = index + self.len if index < 0 or index >= self.len: raise IndexError('piece index is out of range') return self.proto._pieces(index) @@ -202,6 +206,10 @@ class ImmutableNBestSentencePieceText(object): return self.len def __getitem__(self, index): + if isinstance(index, slice): + return [self.proto._nbests(i) for i in range(self.len)][index.start:index.stop:index.step] + if index < 0: + index = index + self.len if index < 0 or index >= self.len: raise IndexError('nbests index is out of range') return self.proto._nbests(index) diff --git a/python/src/sentencepiece/sentencepiece.i b/python/src/sentencepiece/sentencepiece.i index e22f763..2ac68a8 100644 --- a/python/src/sentencepiece/sentencepiece.i +++ b/python/src/sentencepiece/sentencepiece.i @@ -1293,6 +1293,10 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { return self.len def __getitem__(self, index): + if isinstance(index, slice): + return [self.proto._pieces(i) for i in range(self.len)][index.start:index.stop:index.step] + if index < 0: + index = index + self.len if index < 0 or index >= self.len: raise IndexError('piece index is out of range') return self.proto._pieces(index) @@ -1336,6 +1340,10 @@ inline void InitNumThreads(const std::vector &ins, int *num_threads) { return self.len def __getitem__(self, index): + if isinstance(index, slice): + return [self.proto._nbests(i) for i in range(self.len)][index.start:index.stop:index.step] + if index < 0: + index = index + self.len if index < 0 or index >= self.len: raise IndexError('nbests index is out of range') return self.proto._nbests(index) diff --git a/python/test/sentencepiece_test.py b/python/test/sentencepiece_test.py index 6cbe077..92327ac 100755 --- a/python/test/sentencepiece_test.py +++ b/python/test/sentencepiece_test.py @@ -395,6 +395,10 @@ class TestSentencepieceProcessor(unittest.TestCase): self.assertEqual( self.sp_.Decode([x.id for x in s3.nbests[i].pieces]), text) + # slice + self.assertEqual(s1.pieces[::-1], list(reversed(s1.pieces))) + self.assertEqual(s3.nbests[::-1], list(reversed(s3.nbests))) + # Japanese offset s1 = self.jasp_.EncodeAsImmutableProto('吾輩は猫である。Hello world. ABC 123') surfaces1 = [s1.text[x.begin:x.end] for x in s1.pieces]